{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3f190e59",
   "metadata": {},
   "source": [
    "# Text-to-SQL\n",
    "\n",
    "In this tutorial, we’ll see how to implement an agent that leverages SQL using `smolagents`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4baf4579",
   "metadata": {},
   "source": [
    "> Let’s start with the golden question: why not keep it simple and use a standard text-to-SQL pipeline?\n",
    "\n",
    "A standard text-to-sql pipeline is brittle, since the generated SQL query can be incorrect. Even worse, the query could be incorrect, but not raise an error, instead giving some incorrect/useless outputs without raising an alarm.\n",
    "\n",
    "Instead, an agent system is able to critically inspect outputs and decide if the query needs to be changed or not, thus giving it a huge performance boost.\n",
    "\n",
    "Let’s build this agent!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54de0fad",
   "metadata": {},
   "source": [
    "## Installation\n",
    "We will install the necessary packages for this example. We are going to use `sqlalchemy` to create a database and `smolagents` to build our agent. We will use `litellm` for LLM inference and `agentops` for observability."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb5670f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install smolagents\n",
    "%pip install sqlalchemy\n",
    "%pip install agentops"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94cd3def",
   "metadata": {},
   "source": [
    "## Setting up the SQL Table"
   ]
  },
  {
   "cell_type": "code",
   "id": "b12ee7a2",
   "metadata": {},
   "outputs": [],
   "source": "from sqlalchemy import (\n    create_engine,\n    MetaData,\n    Table,\n    Column,\n    String,\n    Integer,\n    Float,\n    insert,\n    inspect,\n    text,\n)\nfrom smolagents import tool\nimport agentops\nfrom dotenv import load_dotenv\nimport os\nfrom smolagents import CodeAgent, LiteLLMModel\n\nengine = create_engine(\"sqlite:///:memory:\")\nmetadata_obj = MetaData()\n\n# create city SQL table\ntable_name = \"receipts\"\nreceipts = Table(\n    table_name,\n    metadata_obj,\n    Column(\"receipt_id\", Integer, primary_key=True),\n    Column(\"customer_name\", String(16), primary_key=True),\n    Column(\"price\", Float),\n    Column(\"tip\", Float),\n)\nmetadata_obj.create_all(engine)\n\nrows = [\n    {\"receipt_id\": 1, \"customer_name\": \"Alan Payne\", \"price\": 12.06, \"tip\": 1.20},\n    {\"receipt_id\": 2, \"customer_name\": \"Alex Mason\", \"price\": 23.86, \"tip\": 0.24},\n    {\"receipt_id\": 3, \"customer_name\": \"Woodrow Wilson\", \"price\": 53.43, \"tip\": 5.43},\n    {\"receipt_id\": 4, \"customer_name\": \"Margaret James\", \"price\": 21.11, \"tip\": 1.00},\n]\nfor row in rows:\n    stmt = insert(receipts).values(**row)\n    with engine.begin() as connection:\n        cursor = connection.execute(stmt)"
  },
  {
   "cell_type": "markdown",
   "id": "7d8b4d1c",
   "metadata": {},
   "source": [
    "## Build our Agent"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b861d52f",
   "metadata": {},
   "source": [
    "We need to create the table description first because the agent will use it to generate the SQL query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61d5d41e",
   "metadata": {},
   "outputs": [],
   "source": [
    "inspector = inspect(engine)\n",
    "columns_info = [(col[\"name\"], col[\"type\"]) for col in inspector.get_columns(\"receipts\")]\n",
    "\n",
    "table_description = \"Columns:\\n\" + \"\\n\".join([f\"  - {name}: {col_type}\" for name, col_type in columns_info])\n",
    "print(table_description)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5272a433",
   "metadata": {},
   "source": [
    "Now we can create the tool that will be used by the agent to perform the SQL query."
   ]
  },
  {
   "cell_type": "code",
   "id": "84fd9b4d",
   "metadata": {},
   "outputs": [],
   "source": "@tool\ndef sql_engine(query: str) -> str:\n    \"\"\"\n    Allows you to perform SQL queries on the table. Returns a string representation of the result.\n    The table is named 'receipts'. Its description is as follows:\n        Columns:\n        - receipt_id: INTEGER\n        - customer_name: VARCHAR(16)\n        - price: FLOAT\n        - tip: FLOAT\n\n    Args:\n        query: The query to perform. This should be correct SQL.\n    \"\"\"\n    output = \"\"\n    with engine.connect() as con:\n        rows = con.execute(text(query))\n        for row in rows:\n            output += \"\\n\" + str(row)\n    return output"
  },
  {
   "cell_type": "markdown",
   "id": "84d7c4a7",
   "metadata": {},
   "source": [
    "Everything is ready to create the agent. We will use the `CodeAgent` class from `smolagents` to create the agent. `litellm` is used to create the model and the agent will use the `sql_engine` tool to perform the SQL query.\n",
    "\n",
    "`agentops` is used to track the agents. We will initialize it with our API key, which can be found in the [AgentOps settings](https://app.agentops.ai/settings/projects)."
   ]
  },
  {
   "cell_type": "code",
   "id": "0cc00a36",
   "metadata": {},
   "outputs": [],
   "source": "load_dotenv()\n\nos.environ[\"AGENTOPS_API_KEY\"] = os.getenv(\"AGENTOPS_API_KEY\", \"your_api_key_here\")\nos.environ[\"OPENAI_API_KEY\"] = os.getenv(\"OPENAI_API_KEY\", \"your_openai_api_key_here\")\n\nagentops.init(auto_start_session=False)\ntracer = agentops.start_trace(\n    trace_name=\"Text-to-SQL\", tags=[\"smolagents\", \"example\", \"text-to-sql\", \"agentops-example\"]\n)\nmodel = LiteLLMModel(\"openai/gpt-4o-mini\")"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "208ef8d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "agent = CodeAgent(\n",
    "    tools=[sql_engine],\n",
    "    model=model,\n",
    ")\n",
    "agent.run(\"Can you give me the name of the client who got the most expensive receipt?\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5cf2ac25",
   "metadata": {},
   "source": [
    "## Level 2: Table Joins"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22fc9aed",
   "metadata": {},
   "source": [
    "Now let’s make it more challenging! We want our agent to handle joins across multiple tables.\n",
    "\n",
    "So let’s make a second table recording the names of waiters for each receipt_id!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2c893cec",
   "metadata": {},
   "outputs": [],
   "source": [
    "table_name = \"waiters\"\n",
    "receipts = Table(\n",
    "    table_name,\n",
    "    metadata_obj,\n",
    "    Column(\"receipt_id\", Integer, primary_key=True),\n",
    "    Column(\"waiter_name\", String(16), primary_key=True),\n",
    ")\n",
    "metadata_obj.create_all(engine)\n",
    "\n",
    "rows = [\n",
    "    {\"receipt_id\": 1, \"waiter_name\": \"Corey Johnson\"},\n",
    "    {\"receipt_id\": 2, \"waiter_name\": \"Michael Watts\"},\n",
    "    {\"receipt_id\": 3, \"waiter_name\": \"Michael Watts\"},\n",
    "    {\"receipt_id\": 4, \"waiter_name\": \"Margaret James\"},\n",
    "]\n",
    "for row in rows:\n",
    "    stmt = insert(receipts).values(**row)\n",
    "    with engine.begin() as connection:\n",
    "        cursor = connection.execute(stmt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5d486ca",
   "metadata": {},
   "source": [
    "Since we changed the table, we update the `SQLExecutorTool` with this table’s description to let the LLM properly leverage information from this table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77c8e2a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "updated_description = \"\"\"Allows you to perform SQL queries on the table. Beware that this tool's output is a string representation of the execution output.\n",
    "It can use the following tables:\"\"\"\n",
    "\n",
    "inspector = inspect(engine)\n",
    "for table in [\"receipts\", \"waiters\"]:\n",
    "    columns_info = [(col[\"name\"], col[\"type\"]) for col in inspector.get_columns(table)]\n",
    "\n",
    "    table_description = f\"Table '{table}':\\n\"\n",
    "\n",
    "    table_description += \"Columns:\\n\" + \"\\n\".join([f\"  - {name}: {col_type}\" for name, col_type in columns_info])\n",
    "    updated_description += \"\\n\\n\" + table_description\n",
    "\n",
    "print(updated_description)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee51dad3",
   "metadata": {},
   "source": [
    "Now let's update the `SQLExecutorTool` with the updated description and run the agent again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6defa5b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "sql_engine.description = updated_description\n",
    "\n",
    "agent = CodeAgent(\n",
    "    tools=[sql_engine],\n",
    "    model=model,\n",
    ")\n",
    "\n",
    "agent.run(\"Which waiter got more total money from tips?\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "739ffd3b",
   "metadata": {},
   "source": [
    "All done! Now we can end the agentops session with a \"Success\" state. You can also end the session with a \"Failure\" or \"Indeterminate\" state, where the \"Indeterminate\" state is used by default."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9c7de64",
   "metadata": {},
   "outputs": [],
   "source": [
    "agentops.end_trace(tracer, end_state=\"Success\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b272d348",
   "metadata": {},
   "source": [
    "You can view the session in the [AgentOps dashboard](https://app.agentops.ai/sessions) by clicking the link provided after ending the session."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "test",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}